/** Copyright 2020-2023 Alibaba Group Holding Limited.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#ifndef MODULES_BASIC_DS_TENSOR_VINEYARD_MOD_
#define MODULES_BASIC_DS_TENSOR_VINEYARD_MOD_

#include <memory>
#include <string>
#include <vector>

#include "arrow/api.h"     // IWYU pragma: keep
#include "arrow/io/api.h"  // IWYU pragma: keep

#include "basic/ds/array.vineyard.h"  // IWYU pragma: keep
#include "basic/ds/arrow.h"           // IWYU pragma: keep
#include "basic/ds/arrow.vineyard.h"
#include "basic/ds/types.h"
#include "client/client.h"
#include "client/ds/blob.h"
#include "client/ds/i_object.h"

namespace vineyard {

#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wattributes"
#endif

namespace detail {

template <typename T>
struct ArrowTensorType {
  using type = arrow::NumericTensor<typename arrow::CTypeTraits<T>::ArrowType>;
};

template <>
struct ArrowTensorType<arrow_string_view> {
  using type = void;
};

}  // namespace detail

// Forward declaration of the builder template; the Tensor classes defined
// later in this file declare the matching instantiation as a friend.
template <typename T>
class TensorBaseBuilder;

/**
 * @brief Element-type-agnostic interface implemented by the tensor classes
 * in this file.
 *
 * It exposes the tensor's metadata (shape, partition index, element type)
 * and its underlying arrow buffers without requiring the element type as a
 * template parameter.
 */
class ITensor : public Object {
 public:
  /**
   * @brief Get the shape of the tensor.
   *
   * @return The shape vector where the ith element represents
   * the size of the ith axis.
   */
  [[shared]] virtual std::vector<int64_t> const& shape() const = 0;

  /**
   * @brief Get the index of this partition in the global tensor.
   *
   * @return The index vector where the ith element represents the index
   * in the ith axis.
   */
  [[shared]] virtual std::vector<int64_t> const& partition_index() const = 0;

  /**
   * @brief Get the type of the tensor's elements.
   */
  [[shared]] virtual AnyType value_type() const = 0;

  /**
   * @brief Get the arrow buffer that holds the tensor's data.
   */
  [[shared]] virtual const std::shared_ptr<arrow::Buffer> buffer() const = 0;

  /**
   * @brief Get a secondary arrow buffer associated with the tensor.
   *
   * @return The secondary buffer. Implementations that have none return
   * nullptr, as the generic `Tensor<T>` in this file does.
   */
  [[shared]] virtual const std::shared_ptr<arrow::Buffer> auxiliary_buffer()
      const = 0;
};

/**
 * @brief A tensor of fixed-size elements of type T whose data lives in a
 * single blob.
 *
 * The element data is read directly from the blob's memory; the shape and
 * the partition index are kept alongside as metadata.
 */
template <typename T>
class [[vineyard]] Tensor : public ITensor, public BareRegistered<Tensor<T>> {
 public:
  using value_t = T;
  using value_pointer_t = T*;
  using value_const_pointer_t = const T*;
  using ArrowTensorT = typename detail::ArrowTensorType<T>::type;

  /**
   * @brief Get the strides of the tensor.
   *
   * The strides are computed for a row-major layout and are expressed in
   * bytes: the innermost stride is `sizeof(T)` and each outer stride is the
   * next inner stride multiplied by the extent of that inner axis.
   *
   * @return The strides of the tensor, one entry per axis. A tensor with an
   * empty shape yields an empty vector. The general notion of tensor
   * strides is described in
   * https://pytorch.org/docs/stable/tensor_attributes.html
   */
  [[shared]] std::vector<int64_t> strides() const {
    std::vector<int64_t> vec(shape_.size());
    // A zero-dimensional tensor has no axes; without this guard
    // `size() - 1` would wrap around and index out of bounds.
    if (vec.empty()) {
      return vec;
    }
    vec.back() = sizeof(T);
    for (size_t i = vec.size() - 1; i > 0; --i) {
      vec[i - 1] = vec[i] * shape_[i];
    }
    return vec;
  }

  /**
   * @brief Get the shape of the tensor.
   *
   * @return The shape vector where the ith element represents
   * the size of the ith axis.
   */
  [[shared]] std::vector<int64_t> const& shape() const override {
    return shape_;
  }

  /**
   * @brief Get the index of this partition in the global tensor.
   *
   * @return The index vector where the ith element represents the index
   * in the ith axis.
   */
  [[shared]] std::vector<int64_t> const& partition_index() const override {
    return partition_index_;
  }

  /**
   * @brief Get the type of tensor's elements.
   *
   * @return The type of the tensor's elements.
   */
  [[shared]] AnyType value_type() const override { return this->value_type_; }

  /**
   * @brief Get the data pointer to the tensor's data buffer.
   *
   * @return The blob's data pointer reinterpreted as a pointer to const T.
   */
  [[shared]] value_const_pointer_t data() const {
    return reinterpret_cast<const T*>(buffer_->data());
  }

  /**
   * @brief Get the data in the tensor by flat index.
   *
   * The index addresses the underlying buffer as a one-dimensional array of
   * T; no bounds checking is performed.
   *
   * @param index Position of the element in the flat buffer.
   * @return A copy of the element at that position.
   */
  [[shared]] const value_t operator[](size_t index) const {
    return this->data()[index];
  }

  /**
   * @brief Get the buffer of the tensor.
   *
   * @return The shared pointer to an arrow buffer which
   * holds the data buffer of the tensor.
   */
  [[shared]] const std::shared_ptr<arrow::Buffer> buffer() const override {
    return this->buffer_->ArrowBuffer();
  }

  /**
   * @brief Get the auxiliary buffer of the tensor.
   *
   * @return Always nullptr: this tensor keeps all of its data in the single
   * buffer returned by buffer().
   */
  [[shared]] const std::shared_ptr<arrow::Buffer> auxiliary_buffer()
      const override {
    return nullptr;
  }

  /**
   * @brief Return a view of the original tensor so that it can be used as
   * arrow's Tensor.
   *
   * @return A newly created arrow tensor constructed from the tensor's
   * arrow buffer and its shape.
   */
  [[shared]] const std::shared_ptr<ArrowTensorT> ArrowTensor() {
    return std::make_shared<ArrowTensorT>(buffer_->ArrowBuffer(), shape());
  }

 private:
  [[shared]] AnyType value_type_;
  [[shared]] std::shared_ptr<Blob> buffer_;
  [[shared]] Tuple<int64_t> shape_;
  [[shared]] Tuple<int64_t> partition_index_;

  friend class Client;
  friend class TensorBaseBuilder<T>;
};

/**
 * @brief Specialization of the tensor for string elements.
 *
 * The strings are stored in a large string array instead of a plain blob,
 * so elements are variable-length and are exposed as string views.
 */
template <>
class [[vineyard]] Tensor<std::string>
    : public ITensor, public BareRegistered<Tensor<std::string>> {
 public:
  using value_t = arrow_string_view;
  using value_pointer_t = uint8_t*;
  using value_const_pointer_t = const uint8_t*;
  using ArrowTensorT =
      typename detail::ArrowTensorType<arrow_string_view>::type;

  /**
   * @brief Get the strides of the tensor.
   *
   * The strides are computed for a row-major layout. Because the elements
   * have no fixed byte width, the innermost stride is 1, i.e., the strides
   * count elements rather than bytes.
   *
   * @return The strides of the tensor, one entry per axis. A tensor with an
   * empty shape yields an empty vector. The general notion of tensor
   * strides is described in
   * https://pytorch.org/docs/stable/tensor_attributes.html
   */
  [[shared]] std::vector<int64_t> strides() const {
    std::vector<int64_t> vec(shape_.size());
    // A zero-dimensional tensor has no axes; without this guard
    // `size() - 1` would wrap around and index out of bounds.
    if (vec.empty()) {
      return vec;
    }
    vec.back() = 1 /* special case for string tensors */;
    for (size_t i = vec.size() - 1; i > 0; --i) {
      vec[i - 1] = vec[i] * shape_[i];
    }
    return vec;
  }

  /**
   * @brief Get the shape of the tensor.
   *
   * @return The shape vector where the ith element represents
   * the size of the ith axis.
   */
  [[shared]] std::vector<int64_t> const& shape() const override {
    return shape_;
  }

  /**
   * @brief Get the index of this partition in the global tensor.
   *
   * @return The index vector where the ith element represents the index
   * in the ith axis.
   */
  [[shared]] std::vector<int64_t> const& partition_index() const override {
    return partition_index_;
  }

  /**
   * @brief Get the type of tensor's elements.
   *
   * @return The type of the tensor's elements.
   */
  [[shared]] AnyType value_type() const override { return this->value_type_; }

  /**
   * @brief Get the data pointer to the tensor's data buffer.
   *
   * @return The raw data pointer of the underlying string array.
   */
  [[shared]] value_const_pointer_t data() const {
    return this->buffer_->GetArray()->raw_data();
  }

  /**
   * @brief Get the data in the tensor by flat index.
   *
   * @param index Position of the element in the underlying string array.
   * @return A view of the string stored at that position.
   */
  [[shared]] const value_t operator[](size_t index) const {
    return this->buffer_->GetArray()->GetView(index);
  }

  /**
   * @brief Get the buffer of the tensor.
   *
   * @return The shared pointer to an arrow buffer which
   * holds the data buffer of the tensor.
   */
  [[shared]] const std::shared_ptr<arrow::Buffer> buffer() const override {
    return this->buffer_->GetArray()->value_data();
  }

  /**
   * @brief Get the auxiliary buffer of the tensor.
   *
   * @return The shared pointer to the arrow buffer which holds the value
   * offsets of the underlying string array. Together with the data buffer
   * returned by buffer() it delimits the individual strings.
   */
  [[shared]] const std::shared_ptr<arrow::Buffer> auxiliary_buffer()
      const override {
    // The data buffer is already exposed by buffer(); the auxiliary buffer
    // carries the offsets needed to slice that data into strings.
    return this->buffer_->GetArray()->value_offsets();
  }

  /**
   * @brief Return a view of the original tensor so that it can be used as
   * arrow's Tensor.
   *
   * @return Always nullptr for string tensors.
   */
  [[shared]] const std::shared_ptr<ArrowTensorT> ArrowTensor() {
    // No corresponding arrow tensor type for std::string.
    return nullptr;
  }

 private:
  [[shared]] AnyType value_type_;
  [[shared]] std::shared_ptr<LargeStringArray> buffer_;
  [[shared]] Tuple<int64_t> shape_;
  [[shared]] Tuple<int64_t> partition_index_;

  friend class Client;
  friend class TensorBaseBuilder<std::string>;
};

#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif

}  // namespace vineyard

#endif  // MODULES_BASIC_DS_TENSOR_VINEYARD_MOD_

// vim: syntax=cpp
